{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chaper 8 - Intrinsic Curiosity Module\n",
    "#### Deep Reinforcement Learning *in Action*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "from nes_py.wrappers import JoypadSpace #A\n",
    "import gym_super_mario_bros\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT #B\n",
    "env = gym_super_mario_bros.make('SuperMarioBros-v0')\n",
    "env = JoypadSpace(env, COMPLEX_MOVEMENT) #C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "done = True\n",
    "for step in range(2500): #D\n",
    "    if done:\n",
    "        state = env.reset()\n",
    "    state, reward, done, info = env.step(env.action_space.sample())\n",
    "    env.render()\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from skimage.transform import resize #A\n",
    "import numpy as np\n",
    "\n",
    "def downscale_obs(obs, new_size=(42,42), to_gray=True):\n",
    "    if to_gray:\n",
    "        return resize(obs, new_size, anti_aliasing=True).max(axis=2) #B\n",
    "    else:\n",
    "        return resize(obs, new_size, anti_aliasing=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(env.render(\"rgb_array\"))\n",
    "plt.imshow(downscale_obs(env.render(\"rgb_array\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from collections import deque\n",
    "\n",
    "def prepare_state(state): #A\n",
    "    return torch.from_numpy(downscale_obs(state, to_gray=True)).float().unsqueeze(dim=0)\n",
    "\n",
    "\n",
    "def prepare_multi_state(state1, state2): #B\n",
    "    state1 = state1.clone()\n",
    "    tmp = torch.from_numpy(downscale_obs(state2, to_gray=True)).float()\n",
    "    state1[0][0] = state1[0][1]\n",
    "    state1[0][1] = state1[0][2]\n",
    "    state1[0][2] = tmp\n",
    "    return state1\n",
    "\n",
    "\n",
    "def prepare_initial_state(state,N=3): #C\n",
    "    state_ = torch.from_numpy(downscale_obs(state, to_gray=True)).float()\n",
    "    tmp = state_.repeat((N,1,1))\n",
    "    return tmp.unsqueeze(dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy(qvalues, eps=None): #A\n",
    "    if eps is not None:\n",
    "        if torch.rand(1) < eps:\n",
    "            return torch.randint(low=0,high=7,size=(1,))\n",
    "        else:\n",
    "            return torch.argmax(qvalues)\n",
    "    else:\n",
    "        return torch.multinomial(F.softmax(F.normalize(qvalues)), num_samples=1) #B"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import shuffle\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class ExperienceReplay:\n",
    "    def __init__(self, N=500, batch_size=100):\n",
    "        self.N = N #A\n",
    "        self.batch_size = batch_size #B\n",
    "        self.memory = [] \n",
    "        self.counter = 0\n",
    "        \n",
    "    def add_memory(self, state1, action, reward, state2):\n",
    "        self.counter +=1 \n",
    "        if self.counter % 500 == 0: #C\n",
    "            self.shuffle_memory()\n",
    "            \n",
    "        if len(self.memory) < self.N: #D\n",
    "            self.memory.append( (state1, action, reward, state2) )\n",
    "        else:\n",
    "            rand_index = np.random.randint(0,self.N-1)\n",
    "            self.memory[rand_index] = (state1, action, reward, state2)\n",
    "    \n",
    "    def shuffle_memory(self): #E\n",
    "        shuffle(self.memory)\n",
    "        \n",
    "    def get_batch(self): #F\n",
    "        if len(self.memory) < self.batch_size:\n",
    "            batch_size = len(self.memory)\n",
    "        else:\n",
    "            batch_size = self.batch_size\n",
    "        if len(self.memory) < 1:\n",
    "            print(\"Error: No data in memory.\")\n",
    "            return None\n",
    "        #G\n",
    "        ind = np.random.choice(np.arange(len(self.memory)),batch_size,replace=False)\n",
    "        batch = [self.memory[i] for i in ind] #batch is a list of tuples\n",
    "        state1_batch = torch.stack([x[0].squeeze(dim=0) for x in batch],dim=0)\n",
    "        action_batch = torch.Tensor([x[1] for x in batch]).long()\n",
    "        reward_batch = torch.Tensor([x[2] for x in batch])\n",
    "        state2_batch = torch.stack([x[3].squeeze(dim=0) for x in batch],dim=0)\n",
    "        return state1_batch, action_batch, reward_batch, state2_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Phi(nn.Module): #A\n",
    "    def __init__(self):\n",
    "        super(Phi, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "\n",
    "    def forward(self,x):\n",
    "        x = F.normalize(x)\n",
    "        y = F.elu(self.conv1(x))\n",
    "        y = F.elu(self.conv2(y))\n",
    "        y = F.elu(self.conv3(y))\n",
    "        y = F.elu(self.conv4(y)) #size [1, 32, 3, 3] batch, channels, 3 x 3\n",
    "        y = y.flatten(start_dim=1) #size N, 288\n",
    "        return y\n",
    "\n",
    "class Gnet(nn.Module): #B\n",
    "    def __init__(self):\n",
    "        super(Gnet, self).__init__()\n",
    "        self.linear1 = nn.Linear(576,256)\n",
    "        self.linear2 = nn.Linear(256,12)\n",
    "\n",
    "    def forward(self, state1,state2):\n",
    "        x = torch.cat( (state1, state2) ,dim=1)\n",
    "        y = F.relu(self.linear1(x))\n",
    "        y = self.linear2(y)\n",
    "        y = F.softmax(y,dim=1)\n",
    "        return y\n",
    "\n",
    "class Fnet(nn.Module): #C\n",
    "    def __init__(self):\n",
    "        super(Fnet, self).__init__()\n",
    "        self.linear1 = nn.Linear(300,256)\n",
    "        self.linear2 = nn.Linear(256,288)\n",
    "\n",
    "    def forward(self,state,action):\n",
    "        action_ = torch.zeros(action.shape[0],12) #D\n",
    "        indices = torch.stack( (torch.arange(action.shape[0]), action.squeeze()), dim=0)\n",
    "        indices = indices.tolist()\n",
    "        action_[indices] = 1.\n",
    "        x = torch.cat( (state,action_) ,dim=1)\n",
    "        y = F.relu(self.linear1(x))\n",
    "        y = self.linear2(y)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Qnetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Qnetwork, self).__init__()\n",
    "        #in_channels, out_channels, kernel_size, stride=1, padding=0\n",
    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.linear1 = nn.Linear(288,100)\n",
    "        self.linear2 = nn.Linear(100,12)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        x = F.normalize(x)\n",
    "        y = F.elu(self.conv1(x))\n",
    "        y = F.elu(self.conv2(y))\n",
    "        y = F.elu(self.conv3(y))\n",
    "        y = F.elu(self.conv4(y))\n",
    "        y = y.flatten(start_dim=2)\n",
    "        y = y.view(y.shape[0], -1, 32)\n",
    "        y = y.flatten(start_dim=1)\n",
    "        y = F.elu(self.linear1(y))\n",
    "        y = self.linear2(y) #size N, 12\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'batch_size':150,\n",
    "    'beta':0.2,\n",
    "    'lambda':0.1,\n",
    "    'eta': 1.0,\n",
    "    'gamma':0.2,\n",
    "    'max_episode_len':100,\n",
    "    'min_progress':15,\n",
    "    'action_repeats':6,\n",
    "    'frames_per_state':3\n",
    "}\n",
    "\n",
    "replay = ExperienceReplay(N=1000, batch_size=params['batch_size'])\n",
    "Qmodel = Qnetwork()\n",
    "encoder = Phi()\n",
    "forward_model = Fnet()\n",
    "inverse_model = Gnet()\n",
    "forward_loss = nn.MSELoss(reduction='none')\n",
    "inverse_loss = nn.CrossEntropyLoss(reduction='none')\n",
    "qloss = nn.MSELoss()\n",
    "all_model_params = list(Qmodel.parameters()) + list(encoder.parameters()) #A\n",
    "all_model_params += list(forward_model.parameters()) + list(inverse_model.parameters())\n",
    "opt = optim.Adam(lr=0.001, params=all_model_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(q_loss, inverse_loss, forward_loss):\n",
    "    loss_ = (1 - params['beta']) * inverse_loss\n",
    "    loss_ += params['beta'] * forward_loss\n",
    "    loss_ = loss_.sum() / loss_.flatten().shape[0]\n",
    "    loss = loss_ + params['lambda'] * q_loss\n",
    "    return loss\n",
    "\n",
    "def reset_env():\n",
    "    \"\"\"\n",
    "    Reset the environment and return a new initial state\n",
    "    \"\"\"\n",
    "    env.reset()\n",
    "    state1 = prepare_initial_state(env.render('rgb_array'))\n",
    "    return state1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ICM(state1, action, state2, forward_scale=1., inverse_scale=1e4):\n",
    "    state1_hat = encoder(state1) #A\n",
    "    state2_hat = encoder(state2)\n",
    "    state2_hat_pred = forward_model(state1_hat.detach(), action.detach()) #B\n",
    "    forward_pred_err = forward_scale * forward_loss(state2_hat_pred, \\\n",
    "                        state2_hat.detach()).sum(dim=1).unsqueeze(dim=1)\n",
    "    pred_action = inverse_model(state1_hat, state2_hat) #C\n",
    "    inverse_pred_err = inverse_scale * inverse_loss(pred_action, \\\n",
    "                                        action.detach().flatten()).unsqueeze(dim=1)\n",
    "    return forward_pred_err, inverse_pred_err"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def minibatch_train(use_extrinsic=True):\n",
    "    state1_batch, action_batch, reward_batch, state2_batch = replay.get_batch() \n",
    "    action_batch = action_batch.view(action_batch.shape[0],1) #A\n",
    "    reward_batch = reward_batch.view(reward_batch.shape[0],1)\n",
    "    \n",
    "    forward_pred_err, inverse_pred_err = ICM(state1_batch, action_batch, state2_batch) #B\n",
    "    i_reward = (1. / params['eta']) * forward_pred_err #C\n",
    "    reward = i_reward.detach() #D\n",
    "    if use_explicit: #E\n",
    "        reward += reward_batch \n",
    "    qvals = Qmodel(state2_batch) #F\n",
    "    reward += params['gamma'] * torch.max(qvals)\n",
    "    reward_pred = Qmodel(state1_batch)\n",
    "    reward_target = reward_pred.clone()\n",
    "    indices = torch.stack( (torch.arange(action_batch.shape[0]), \\\n",
    "    action_batch.squeeze()), dim=0)\n",
    "    indices = indices.tolist()\n",
    "    reward_target[indices] = reward.squeeze()\n",
    "    q_loss = 1e5 * qloss(F.normalize(reward_pred), F.normalize(reward_target.detach()))\n",
    "    return forward_pred_err, inverse_pred_err, q_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 8.13"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 5000\n",
    "env.reset()\n",
    "state1 = prepare_initial_state(env.render('rgb_array'))\n",
    "eps=0.15\n",
    "losses = []\n",
    "episode_length = 0\n",
    "switch_to_eps_greedy = 1000\n",
    "state_deque = deque(maxlen=params['frames_per_state'])\n",
    "e_reward = 0.\n",
    "last_x_pos = env.env.env._x_position #A\n",
    "ep_lengths = []\n",
    "use_explicit = False\n",
    "for i in range(epochs):\n",
    "    opt.zero_grad()\n",
    "    episode_length += 1\n",
    "    q_val_pred = Qmodel(state1) #B\n",
    "    if i > switch_to_eps_greedy: #C\n",
    "        action = int(policy(q_val_pred,eps))\n",
    "    else:\n",
    "        action = int(policy(q_val_pred))\n",
    "    for j in range(params['action_repeats']): #D\n",
    "        state2, e_reward_, done, info = env.step(action)\n",
    "        last_x_pos = info['x_pos']\n",
    "        if done:\n",
    "            state1 = reset_env()\n",
    "            break\n",
    "        e_reward += e_reward_\n",
    "        state_deque.append(prepare_state(state2))\n",
    "    state2 = torch.stack(list(state_deque),dim=1) #E\n",
    "    replay.add_memory(state1, action, e_reward, state2) #F\n",
    "    e_reward = 0\n",
    "    if episode_length > params['max_episode_len']: #G\n",
    "        if (info['x_pos'] - last_x_pos) < params['min_progress']:\n",
    "            done = True\n",
    "        else:\n",
    "            last_x_pos = info['x_pos']\n",
    "    if done:\n",
    "        ep_lengths.append(info['x_pos'])\n",
    "        state1 = reset_env()\n",
    "        last_x_pos = env.env.env._x_position\n",
    "        episode_length = 0\n",
    "    else:\n",
    "        state1 = state2\n",
    "    if len(replay.memory) < params['batch_size']:\n",
    "        continue\n",
    "    forward_pred_err, inverse_pred_err, q_loss = minibatch_train(use_extrinsic=False) #H\n",
    "    loss = loss_fn(q_loss, forward_pred_err, inverse_pred_err) #I\n",
    "    loss_list = (q_loss.mean(), forward_pred_err.flatten().mean(),\\\n",
    "    inverse_pred_err.flatten().mean())\n",
    "    losses.append(loss_list)\n",
    "    loss.backward()\n",
    "    opt.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Test Trained Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "done = True\n",
    "state_deque = deque(maxlen=params['frames_per_state'])\n",
    "for step in range(5000):\n",
    "    if done:\n",
    "        env.reset()\n",
    "        state1 = prepare_initial_state(env.render('rgb_array'))\n",
    "    q_val_pred = Qmodel(state1)\n",
    "    action = int(policy(q_val_pred,eps))\n",
    "    state2, reward, done, info = env.step(action)\n",
    "    state2 = prepare_multi_state(state1,state2)\n",
    "    state1=state2\n",
    "    env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:deeprl]",
   "language": "python",
   "name": "conda-env-deeprl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
